Generative Adversarial Networks (GAN)
Table of Contents
Source
If $P_{\text{model}}(x)$ can be estimated to be close to $P_{\text{data}}(x)$, then new data can be generated by sampling from $P_{\text{model}}(x)$.
In generative modeling, we'd like to train a network that models a distribution, such as a distribution over images.
GANs do not work with an explicit density function!
Instead, take game-theoretic approach
One way to judge the quality of the model is to sample from it.
The generator is trained to produce samples that are indistinguishable from the real data, as judged by a discriminator network whose job is to tell real from fake.
$$\text{loss} = -y \log h(x) - (1-y) \log (1-h(x))$$
Non-Saturating Game when the generator is trained
Step 1: Fix $G$ and perform a gradient step to
$$\max_{D} E_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + E_{z \sim p_{z}(z)}\left[\log (1-D(G(z)))\right]$$
Step 2: Fix $D$ and perform a gradient step to
$$\max_{G} E_{z \sim p_{z}(z)}\left[\log D(G(z))\right]$$
OR
Step 1: Fix $G$ and perform a gradient step to
$$\min_{D} E_{x \sim p_{\text{data}}(x)}\left[-\log D(x)\right] + E_{z \sim p_{z}(z)}\left[-\log (1-D(G(z)))\right]$$
Step 2: Fix $D$ and perform a gradient step to
$$\min_{G} E_{z \sim p_{z}(z)}\left[-\log D(G(z))\right]$$
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
# Load MNIST and keep only the digit-"2" images for this single-class GAN.
# Pixels are scaled to [0, 1] (matches the generator's sigmoid output) and
# each 28x28 image is flattened to a 784-vector for the dense networks.
mnist = tf.keras.datasets.mnist
(train_x, train_y), _ = mnist.load_data()
train_x = train_x[np.where(train_y == 2)]
train_x = train_x / 255.0
train_x = train_x.reshape(-1, 784)
print('train_images :', train_x.shape)  # fixed typo: was 'train_iamges'
# Generator: 100-dim latent noise vector -> flattened 28x28 image in [0, 1].
generator = tf.keras.models.Sequential()
generator.add(tf.keras.layers.Dense(units=256, input_dim=100, activation='relu'))
generator.add(tf.keras.layers.Dense(units=784, activation='sigmoid'))
# Discriminator: flattened 28x28 image (784) -> probability the image is real.
discriminator = tf.keras.models.Sequential()
discriminator.add(tf.keras.layers.Dense(units=256, input_dim=784, activation='relu'))
discriminator.add(tf.keras.layers.Dense(units=1, activation='sigmoid'))
# Compiled standalone so it can be trained directly on real/fake batches.
discriminator.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                      loss='binary_crossentropy')
# Stacked model used only to train the generator:
# noise -> generator -> (frozen) discriminator -> validity score.
combined_input = tf.keras.layers.Input(shape = (100,))
generated = generator(combined_input)
# Freeze the discriminator BEFORE building/compiling `combined`, so generator
# updates through this model do not also update the discriminator. The
# standalone `discriminator` was compiled above, so it still trains on its own.
discriminator.trainable = False
combined_output = discriminator(generated)
combined = tf.keras.models.Model(inputs = combined_input, outputs = combined_output)
combined.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0002),
                 loss = 'binary_crossentropy')
def make_noise(samples, dim=100):
    """Draw latent noise vectors from a standard normal distribution.

    Args:
        samples: number of noise vectors to generate.
        dim: dimensionality of each vector (default 100, matching the
            generator's input layer; parameterized so other latent sizes
            can reuse this helper).

    Returns:
        ndarray of shape (samples, dim).
    """
    return np.random.normal(0, 1, [samples, dim])
def plot_generated_images(generator, samples=3):
    """Sample `samples` latent vectors, run them through the generator, and
    show the resulting 28x28 images side by side."""
    images = generator.predict(make_noise(samples)).reshape(samples, 28, 28)
    for idx, img in enumerate(images):
        plt.subplot(1, samples, idx + 1)
        plt.imshow(img, 'gray', interpolation='nearest')
        plt.axis('off')
    plt.tight_layout()
    plt.show()
Step 1: Fix $G$ and perform a gradient step to
$$\min_{D} E_{x \sim p_{\text{data}}(x)}\left[-\log D(x)\right] + E_{z \sim p_{z}(z)}\left[-\log (1-D(G(z)))\right]$$

Step 2: Fix $D$ and perform a gradient step to
$$\min_{G} E_{z \sim p_{z}(z)}\left[-\log D(G(z))\right]$$

n_iter = 20000
batch_size = 100
# Target labels for the discriminator: 1 = real MNIST digit, 0 = generated.
fake = np.zeros(batch_size)
real = np.ones(batch_size)
for i in range(n_iter):
    # Train Discriminator
    # One gradient step on a fake batch and one on a real batch; the reported
    # D_loss is the sum of the two batch losses.
    noise = make_noise(batch_size)
    generated_images = generator.predict(noise)
    idx = np.random.randint(0, train_x.shape[0], batch_size)
    real_images = train_x[idx]
    D_loss_real = discriminator.train_on_batch(real_images, real)
    D_loss_fake = discriminator.train_on_batch(generated_images, fake)
    D_loss = D_loss_real + D_loss_fake
    # Train Generator
    # `combined` was compiled with discriminator.trainable = False, so this
    # step updates only the generator, pushing D(G(z)) toward the "real" label
    # (the non-saturating generator loss).
    noise = make_noise(batch_size)
    G_loss = combined.train_on_batch(noise, real)
    if i % 5000 == 0:
        # Periodic progress report with a few sample images.
        print('Discriminator Loss: ', D_loss)
        print('Generator Loss: ', G_loss)
        plot_generated_images(generator)
# Final samples after training completes.
plot_generated_images(generator)
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
# Load the full MNIST dataset for the conditional GAN: scale pixels to [0, 1],
# flatten images to 784-vectors, and reshape labels into column vectors.
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0
x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)
y_train = y_train.reshape(-1, 1)
y_test = y_test.reshape(-1, 1)
print('x_train: ', x_train.shape)
print('x_test: ', x_test.shape)
print('y_train: ', y_train.shape)
print('y_test: ', y_test.shape)
# Conditional generator core: input is the 128-dim noise concatenated with a
# 10-way one-hot class label, hence input_dim = 138 (= 128 + 10).
generator_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 256, input_dim = 138, activation = 'relu'),
    tf.keras.layers.Dense(units = 784, activation = 'sigmoid')
])
noise = tf.keras.layers.Input(shape = (128,))
label = tf.keras.layers.Input(shape = (1,))
# One-hot encode the integer digit label (10 classes).
label_onehot = tf.keras.layers.CategoryEncoding(10, output_mode='one_hot')(label)
model_input = tf.keras.layers.concatenate([noise, label_onehot], axis = 1)
generated_image = generator_model(model_input)
# Functional wrapper: [noise, label] -> flattened 28x28 image in [0, 1].
generator = tf.keras.models.Model([noise, label], generated_image)
generator.summary()
# Conditional discriminator core: input is the flattened image concatenated
# with a 10-way one-hot label, hence input_dim = 794 (= 784 + 10).
discriminator_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 256,input_dim = 794, activation = 'relu'),
    tf.keras.layers.Dense(units = 1, activation = 'sigmoid')
])
input_image = tf.keras.layers.Input(shape = (784,))
label = tf.keras.layers.Input(shape = (1,))
# One-hot encode the integer digit label (10 classes).
label_onehot = tf.keras.layers.CategoryEncoding(10, output_mode='one_hot')(label)
model_input = tf.keras.layers.concatenate([input_image, label_onehot], axis = 1)
validity = discriminator_model(model_input)
# Functional wrapper: [image, label] -> probability the (image, label) pair is real.
discriminator = tf.keras.models.Model([input_image, label], validity)
optim_d = tf.keras.optimizers.Adam(learning_rate = 0.0002)
# Compiled standalone so the discriminator can be trained on real/fake batches.
discriminator.compile(loss = ['binary_crossentropy'],
                      optimizer = optim_d)
discriminator.summary()
# Stacked CGAN model used only to train the generator:
# [noise, label] -> generator -> (frozen) discriminator -> validity score.
noise = tf.keras.layers.Input(shape = (128,))
label = tf.keras.layers.Input(shape = (1,))
generated_image = generator([noise, label])
# Freeze the discriminator BEFORE compiling `combined` so only the generator
# is updated through this model; the standalone `discriminator` was compiled
# earlier and still trains on its own batches.
discriminator.trainable = False
validity = discriminator([generated_image, label])
combined = tf.keras.models.Model([noise, label], validity)
optim_combined = tf.keras.optimizers.Adam(learning_rate = 0.0002)
combined.compile(loss = ['binary_crossentropy'],
                 optimizer = optim_combined)
combined.summary()
def create_noise(samples, dim=128):
    """Draw latent noise vectors from a standard normal distribution.

    Args:
        samples: number of noise vectors to generate.
        dim: dimensionality of each vector (default 128, matching the CGAN
            generator's noise input; parameterized for reuse).

    Returns:
        ndarray of shape (samples, dim).
    """
    return np.random.normal(0, 1, [samples, dim])
def plot_generated_images(generator):
    """Generate and display one digit per class label (0-9) from the CGAN."""
    labels = np.arange(0, 10).reshape(-1, 1)
    images = generator.predict([create_noise(10), labels])
    plt.figure(figsize=(90, 10))
    for idx in range(images.shape[0]):
        plt.subplot(1, 10, idx + 1)
        plt.imshow(images[idx].reshape((28, 28)), 'gray', interpolation='nearest')
        plt.title('Digit: {}'.format(idx), fontsize=75)
        plt.axis('off')
    plt.show()
n_iter = 100000
batch_size = 100
# Target labels for the discriminator: 1 = real (image, label) pair, 0 = fake.
valid = np.ones(batch_size)
fake = np.zeros(batch_size)
for i in range(n_iter):
    # Train Discriminator
    # Real images keep their true labels; fakes are generated conditioned on
    # those same labels so the discriminator judges (image, label) pairs.
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_images, labels = x_train[idx], y_train[idx]
    noise = create_noise(batch_size)
    generated_images = generator.predict([noise, labels])
    d_loss_real = discriminator.train_on_batch([real_images, labels], valid)
    d_loss_fake = discriminator.train_on_batch([generated_images, labels], fake)
    d_loss = d_loss_real + d_loss_fake
    # Train Generator
    # Fresh noise and uniformly random class labels; `combined` was compiled
    # with discriminator.trainable = False, so only the generator is updated.
    noise = create_noise(batch_size)
    labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
    g_loss = combined.train_on_batch([noise, labels], valid)
    if i % 5000 == 0:
        # Periodic progress report with one sample per digit class.
        print('Discriminator Loss: ', d_loss)
        print('Generator Loss: ', g_loss)
        plot_generated_images(generator)
Generate fake MNIST images with the trained CGAN
# Display one generated sample per digit class from the trained CGAN generator.
plot_generated_images(generator)
Ian Goodfellow, et al., "Generative Adversarial Nets" NIPS, 2014.
NIPS 2016 tutorial by Ian Goodfellow
%%html
<center><iframe src="https://www.youtube.com/embed/9JpdAg6uMXs?rel=0"
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
%%html
<center><iframe src="https://www.youtube.com/embed/5WoItGTWV54?rel=0"
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
MIT by Aaron Courville
%%html
<center><iframe src="https://www.youtube.com/embed/JVb54xhEw6Y?rel=0"
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
Univ. of Waterloo by Ali Ghodsi
%%html
<center><iframe src="https://www.youtube.com/embed/7G4_Y5rsvi8?rel=0"
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
%%html
<center><iframe src="https://www.youtube.com/embed/odpjk7_tGY0?rel=0"
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
%%javascript
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')